# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import gmpy2
from gmpy2 import mpq, mpz, mpfr, f2q
from hysop.constants import (
HYSOP_REAL,
HYSOP_INTEGER,
HYSOP_INDEX,
HYSOP_BOOL,
HYSOP_COMPLEX,
)
# Concrete classes of the objects produced by the gmpy2 constructors, taken
# from sample values rather than from the imported names themselves.
# NOTE(review): presumably because mpq/mpz/mpfr/f2q may be factory functions
# rather than classes in some gmpy2 versions — confirm.
MPQ, MPZ, MPFR, F2Q = (
    type(mpq(0)),
    type(mpz(0)),
    type(mpfr(0)),
    type(f2q(0)),
)
def _mpqize(x):
    """Convert a scalar to a gmpy2 rational (mpq).

    Python ints become exact ``x/1`` rationals, Python floats (including
    numpy float64, a float subclass) go through ``gmpy2.f2q``, and anything
    else is parsed from its string representation.
    """
    if isinstance(x, int):
        return mpq(x, 1)
    if isinstance(x, float):
        return f2q(x)
    return mpq(str(x))


# Element-wise version of _mpqize, usable on numpy arrays.
mpqize = np.vectorize(_mpqize)
[docs]
def get_dtype(x):
    """Resolve the numpy scalar type (or dtype) associated with x.

    Rules, checked in this order:

    * ``np.dtype`` instance -> its scalar type (``x.type``);
    * object with a ``dtype`` attribute:
        - callable ``dtype`` -> the result of calling it,
        - ``dtype`` is a getset descriptor (x is itself a numpy scalar
          type such as ``np.float64``) -> x,
        - otherwise -> ``x.dtype`` unchanged (e.g. an array's ``np.dtype``);
    * Python ``int`` (``bool`` included, being an int subclass) -> ``np.int64``;
    * Python ``float`` -> ``np.float64``;
    * Python ``complex`` -> ``np.complex128``;
    * ``None`` -> ``None``.

    Raises
    ------
    TypeError
        If x matches none of the rules above.
    """
    if isinstance(x, np.dtype):
        return x.type

    if hasattr(x, "dtype"):
        attr = x.dtype
        if callable(attr):
            return attr()
        # On a numpy scalar *type*, dtype is a class-level descriptor,
        # dtype.type has a dtype field...
        if attr.__class__.__name__ == "getset_descriptor":
            return x
        return attr

    python_scalars = (
        (int, np.int64),
        (float, np.float64),
        (complex, np.complex128),
    )
    for pytype, nptype in python_scalars:
        if isinstance(x, pytype):
            return nptype

    if x is None:
        return None

    raise TypeError("Unknown type in get_dtype (got {}).".format(x.__class__))
[docs]
def get_itemsize(x):
    """Return the size in bytes of one element of the type resolved from x.

    x is resolved with get_dtype and then normalized through ``np.dtype``.
    """
    return np.dtype(get_dtype(x)).itemsize
[docs]
def is_fp(x):
    """Return True if x resolves (via get_dtype) to a real floating point type.

    Accepted types: float16, float32, float64 and longdouble.
    """
    float_types = (np.float16, np.float32, np.float64, np.longdouble)
    resolved = get_dtype(x)
    return resolved in float_types
[docs]
def is_signed(x):
    """Return True if x resolves (via get_dtype) to a signed integer type.

    Accepted types: int8, int16, int32 and int64.
    """
    signed_types = (np.int8, np.int16, np.int32, np.int64)
    resolved = get_dtype(x)
    return resolved in signed_types
[docs]
def is_unsigned(x):
    """Return True if x resolves (via get_dtype) to an unsigned integer type.

    Accepted types: bool_ (treated as unsigned here), uint8, uint16, uint32
    and uint64.
    """
    unsigned_types = (np.bool_, np.uint8, np.uint16, np.uint32, np.uint64)
    resolved = get_dtype(x)
    return resolved in unsigned_types
[docs]
def is_integer(x):
    """Return True if x resolves to a signed or unsigned integer type.

    Bool counts as an integer because is_unsigned accepts it.
    """
    return any(check(x) for check in (is_signed, is_unsigned))
[docs]
def is_complex(x):
    """Return True if x resolves (via get_dtype) to a complex type.

    Accepted types: complex64, complex128 and clongdouble.
    """
    complex_types = (np.complex64, np.complex128, np.clongdouble)
    resolved = get_dtype(x)
    return resolved in complex_types
[docs]
def default_invalid_value(dtype):
    """Return a sentinel "invalid" value suited to the given type.

    Complex types get NaN+NaNj, real floating point types get NaN, and
    integer types (bool included, see is_unsigned) get 0.

    Raises
    ------
    NotImplementedError
        If dtype is neither complex, floating point nor integer.
    """
    nan = float("nan")
    if is_complex(dtype):
        return 1.0 * nan + 1.0j * nan
    if is_fp(dtype):
        return nan
    if is_unsigned(dtype) or is_signed(dtype):
        # Integers have no NaN representation; zero is used instead.
        return 0
    raise NotImplementedError
# promote_dtype
[docs]
def match_dtype(x, dtype):
    """Promote x.dtype to dtype (always safe cast).

    Parameters
    ----------
    x:
        Anything accepted by get_dtype (numpy dtype, scalar type, array,
        Python scalar or None).
    dtype:
        Either a one-letter kind code, or a concrete dtype / scalar type,
        or None.  Kind codes promote the type of x against the smallest
        type of that kind: "f" -> float16, "i" -> int8, "u" -> uint8,
        "b" -> HYSOP_BOOL, "c" -> complex64.

    Returns
    -------
    The result of ``np.promote_types`` between the type of x and the target.
    When dtype is not a kind code, a None type for x returns dtype unchanged
    and a None dtype returns the type of x unchanged.

    Raises
    ------
    NotImplementedError
        If dtype is a string that is not one of the supported kind codes.
    """
    xtype = get_dtype(x)
    if isinstance(dtype, str):
        if dtype == "f":
            return np.promote_types(xtype, np.float16)
        elif dtype == "i":
            return np.promote_types(xtype, np.int8)
        elif dtype == "u":
            return np.promote_types(xtype, np.uint8)
        elif dtype == "b":
            return np.promote_types(xtype, HYSOP_BOOL)
        elif dtype == "c":
            return np.promote_types(xtype, np.complex64)
        else:
            raise NotImplementedError(dtype)
    elif xtype is None:
        return dtype
    elif dtype is None:
        return xtype
    else:
        # Promote rather than blindly returning dtype: the result must be
        # able to hold values of both types (safe cast), as the kind-code
        # branches above already guarantee.
        return np.promote_types(xtype, dtype)
[docs]
def demote_dtype(x, dtype):
    """Demote x.dtype to dtype (not a safe cast).

    When dtype is a one-letter kind code ("c", "f", "i" or "u"), return the
    numpy type of that kind whose (component) size matches the item size of
    x. For complex x, the size of one real component (half the item size)
    is used. Otherwise dtype is the explicit target and is returned as-is,
    except that a None dtype returns the type of x.

    Raises
    ------
    TypeError
        If dtype is a kind code but the type of x resolves to None.
    NotImplementedError
        If dtype is a string that is not a supported kind code.
    KeyError
        If no type of the requested kind has the required size.
    """
    xtype = get_dtype(x)
    if isinstance(dtype, str):
        if xtype is None:
            msg = "Cannot demote an unknown type (None) to kind {!r}."
            raise TypeError(msg.format(dtype))
        # np.dtype() accepts both scalar types (e.g. np.float32) and dtype
        # instances (e.g. an array's .dtype, as returned by get_dtype),
        # whereas instantiating xtype(0) only works for scalar types.
        n = np.dtype(xtype).itemsize
        if is_complex(xtype):
            # Size of a single real component.
            n //= 2
        if dtype == "c":
            return {
                1: np.complex64,
                2: np.complex64,
                4: np.complex64,
                8: np.complex128,
                16: np.clongdouble,
            }[n]
        elif dtype == "f":
            return {
                1: np.float16,
                2: np.float16,
                4: np.float32,
                8: np.float64,
                16: np.longdouble,
            }[n]
        elif dtype == "i":
            return {1: np.int8, 2: np.int16, 4: np.int32, 8: np.int64}[n]
        elif dtype == "u":
            return {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}[n]
        else:
            raise NotImplementedError(dtype)
    elif xtype is None:
        return dtype
    elif dtype is None:
        return xtype
    else:
        return dtype
[docs]
def match_float_type(x):
    """Promote the type of x against the smallest floating point type.

    Thin wrapper around match_dtype with the "f" kind code.
    """
    return match_dtype(x, dtype="f")
[docs]
def match_signed_type(x):
    """Promote the type of x against the smallest signed integer type.

    Thin wrapper around match_dtype with the "i" kind code.
    """
    return match_dtype(x, dtype="i")
[docs]
def match_unsigned_type(x):
    """Promote the type of x against the smallest unsigned integer type.

    Thin wrapper around match_dtype with the "u" kind code (the signed "i"
    code was previously passed here by mistake).
    """
    return match_dtype(x, "u")
[docs]
def match_complex_type(x):
    """Promote the type of x against the smallest complex type.

    Thin wrapper around match_dtype with the "c" kind code.
    """
    return match_dtype(x, dtype="c")
[docs]
def match_bool_type(x):
    """Promote the type of x against the boolean type.

    Thin wrapper around match_dtype with the "b" kind code.
    """
    return match_dtype(x, dtype="b")
[docs]
def complex_to_float_dtype(dtype):
    """Return the real floating point type of a complex type's components.

    complex64 -> float32, complex128 -> float64, clongdouble -> longdouble.
    dtype may be anything accepted by get_dtype.

    Raises
    ------
    AssertionError
        If dtype does not resolve to a complex type (when asserts are on).
    RuntimeError
        If the complex type is not one of the three handled above.
    """
    dtype = get_dtype(dtype)
    assert is_complex(dtype), f"{dtype} is not a complex"
    if dtype == np.complex64:
        return np.float32
    elif dtype == np.complex128:
        return np.float64
    elif dtype == np.clongdouble:
        return np.longdouble
    else:
        # The template must exist before it is formatted (previously this
        # raised UnboundLocalError instead of the intended RuntimeError).
        msg = "Unknown complex type {}."
        msg = msg.format(dtype)
        raise RuntimeError(msg)
[docs]
def float_to_complex_dtype(dtype):
    """Return the complex type whose components have the given float type.

    float32 -> complex64, float64 -> complex128, longdouble -> clongdouble.
    dtype may be anything accepted by get_dtype.

    Raises
    ------
    AssertionError
        If dtype does not resolve to a floating point type (asserts on).
    RuntimeError
        For floating point types without a mapping (e.g. float16).
    """
    dtype = get_dtype(dtype)
    assert is_fp(dtype), f"{dtype} is not a float"
    pairs = (
        (np.float32, np.complex64),
        (np.float64, np.complex128),
        (np.longdouble, np.clongdouble),
    )
    for real_type, complex_type in pairs:
        if dtype == real_type:
            return complex_type
    raise RuntimeError("Unknown float type {}.".format(dtype))
[docs]
def determine_fp_types(dtype):
    """Return the (real, complex) pair of np.dtype matching dtype.

    dtype may be either the real floating point member of the pair or the
    complex member; the other one is derived from it.

    Raises
    ------
    ValueError
        If dtype is neither a floating point nor a complex type.
    """
    if is_fp(dtype):
        ftype, ctype = dtype, float_to_complex_dtype(dtype)
    elif is_complex(dtype):
        ftype, ctype = complex_to_float_dtype(dtype), dtype
    else:
        raise ValueError(f"{dtype} is not a floating point or complex data type.")
    return np.dtype(ftype), np.dtype(ctype)
[docs]
def find_common_dtype(*args):
    """Find a numpy type able to represent values of all the given types.

    Each argument is resolved with get_dtype. The result kind is the widest
    kind present (complex > floating point > signed > unsigned; bool counts
    as unsigned, see is_unsigned). Its size is the largest item size among
    the arguments.

    For a complex result, every non-complex argument counts twice its item
    size, because a complex value stores two real components. For example,
    complex64 and float64 yield complex128, not complex64.

    Raises
    ------
    ValueError
        If no argument is given.
    KeyError
        If no type of the selected kind has the required size.
    NotImplementedError
        If no argument resolves to a type of a known kind.
    """
    if not args:
        raise ValueError("find_common_dtype() requires at least one argument.")
    dtypes = tuple(get_dtype(arg) for arg in args)
    itemsizes = tuple(get_itemsize(x) for x in dtypes)
    if any(is_complex(x) for x in dtypes):
        # Size of the complex type needed to hold each argument's precision.
        n = max(
            size if is_complex(x) else 2 * size
            for x, size in zip(dtypes, itemsizes)
        )
        return {8: np.complex64, 16: np.complex128, 32: np.clongdouble}[n]
    n = max(itemsizes)
    if any(is_fp(x) for x in dtypes):
        return {2: np.float16, 4: np.float32, 8: np.float64, 16: np.longdouble}[n]
    elif any(is_signed(x) for x in dtypes):
        return {1: np.int8, 2: np.int16, 4: np.int32, 8: np.int64}[n]
    elif any(is_unsigned(x) for x in dtypes):
        return {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}[n]
    else:
        msg = "Did not find any matching dtype."
        raise NotImplementedError(msg)